import numpy as np
import pandas as pd
from scipy.stats import truncnorm
from scipy.optimize import minimize
from functools import partial
from typing import Dict, List, Tuple
import sys, time, random, os

# --- Environment Setup (Keep outside the loop)
# Work from the project folder: results are written relative to os.getcwd()
# (see the 'output' path built in run_simulation).
new_directory = 'C:\\Users\\xt9\\Box\\Python Codes'
os.chdir(new_directory)

# Seed both random number generators so every run reproduces the same draws.
SEED_VALUE = 42
np.random.seed(SEED_VALUE)
random.seed(SEED_VALUE)

# --- Function Definitions (Keep outside the loop)

# calculate link formation probability
def p(xi, xj, design, prob):
    """Return the matrix of pairwise link-formation probabilities.

    Entry (a, b) is the probability of a link between an agent whose
    characteristic is ``xi[a, 0]`` and one whose characteristic is ``xj[0, b]``.
    The probability depends only on the unordered pair of values, so
    p(u, v) == p(v, u).

    Args:
        xi: column vector of shape (n, 1) of characteristics.
        xj: row vector of shape (1, n) of characteristics.
        design: 11 or 12 (support of x is {0, 1}) or 2 (support is {0, 1, 2}).
        prob: probabilities for each unordered pair of support values, in
            upper-triangular order:
            - designs 11/12: [p(0,0), p(0,1), p(1,1)]
            - design 2: [p(0,0), p(0,1), p(0,2), p(1,1), p(1,2), p(2,2)]

    Returns:
        p_matrix: float array of shape (xi.shape[0], xj.shape[1]). Pairs with a
        value outside the support keep probability 0.

    Raises:
        ValueError: if ``design`` is not 11, 12 or 2 (previously this surfaced
            as an UnboundLocalError on ``p_matrix``).
    """
    if design in (11, 12):
        support = (0, 1)
    elif design == 2:
        support = (0, 1, 2)
    else:
        raise ValueError(f"Unsupported design {design!r}; expected 11, 12 or 2.")

    p_matrix = np.zeros((xi.shape[0], xj.shape[1]), dtype=float)
    # Walk the unordered pairs (u, v) with u <= v in the same order as `prob`:
    # (0,0), (0,1), ..., (1,1), ... The masks are disjoint, so order of assignment
    # does not matter.
    k = 0
    for pos, u in enumerate(support):
        for v in support[pos:]:
            mask = ((xi == u) & (xj == v)) | ((xi == v) & (xj == u))
            p_matrix[mask] = prob[k]
            k += 1

    return p_matrix

# discretize m_raw into m; returns the discretized matrix and the bin boundaries (cutoffs)
def discretize_m(x_matrix: np.ndarray, C: int, rel_freq: float = 0.10):
    """
    Discretize a matrix of non-negative integer counts into C ordered bins.

    The transformation uses C-1 integer boundaries b_1 <= ... <= b_{C-1}:
    - Y = 1 if X <= b_1
    - Y = k if b_{k-1} < X <= b_k   (k = 2, ..., C-1)
    - Y = C if X > b_{C-1}

    Targets (best effort; they can be violated for highly concentrated data,
    in which case a warning is printed):
    1. Freq(Y=1) >= rel_freq and Freq(Y=C) >= rel_freq.
    2. Intermediate bins (Y=2...C-1) have roughly equal frequencies.

    Args:
        x_matrix: array of counts (any shape). Values are cast to int64
            (truncating) to build the histogram, which requires non-negative
            values (np.bincount); the original values are what gets binned.
        C: number of bins; must be at least 2.
        rel_freq: target minimum share of observations in each outer bin.

    Returns:
        y_matrix: integer array of the same shape as x_matrix, values in {1, ..., C}.
        boundaries: list of the C-1 non-decreasing boundaries [b_1, ..., b_{C-1}],
            each capped at max(x).

    Raises:
        ValueError: if C < 2 or x_matrix is empty.
    """

    if C < 2:
        raise ValueError("C must be at least 2.")

    x_flat = x_matrix.flatten().astype(np.int64)
    N = len(x_flat)
    # Check emptiness BEFORE calling max(): max() of an empty array raises a
    # generic "zero-size array" ValueError, which made this check unreachable.
    if N == 0:
        raise ValueError("Input matrix is empty.")
    B = x_flat.max()

    # 1. Histogram of the values 0..B and its running total.
    counts = np.bincount(x_flat, minlength=B + 1)
    cum_counts = np.cumsum(counts)

    # 2. b_1: smallest value with Freq(X <= b_1) >= rel_freq.
    min_count_1 = rel_freq * N
    # searchsorted(side='left') gives the first index whose cumulative count
    # reaches the target.
    b1 = np.searchsorted(cum_counts, min_count_1, side='left')
    b1 = min(b1, B)  # cap at the largest observed value

    # 3. b_{C-1}: largest value with Freq(X <= b_{C-1}) <= 1 - rel_freq,
    #    i.e. Freq(X > b_{C-1}) >= rel_freq.
    max_count_C_minus_1 = (1 - rel_freq) * N
    b_C_minus_1 = np.searchsorted(cum_counts, max_count_C_minus_1, side='right') - 1
    b_C_minus_1 = max(b_C_minus_1, 0)  # floor at 0

    # 4. Both outer constraints cannot hold at once: collapse to a midpoint.
    if b1 > b_C_minus_1:
        print(f"Warning: Data distribution conflicts with constraints (b1={b1}, b_C_1={b_C_minus_1}). "
              "This can happen with highly concentrated data. "
              "Adjusting boundaries, but bins may be skewed or empty.")
        b_mid = (b1 + b_C_minus_1) // 2
        b1 = b_mid
        b_C_minus_1 = b_mid

    # 5. Intermediate boundaries b_2, ..., b_{C-2}: split the counts between
    #    cum_counts[b1] and cum_counts[b_C_minus_1] into C-2 equal-frequency bins.
    boundaries = [b1]
    num_intermediate_bins = C - 2

    if num_intermediate_bins > 0:
        start_count = cum_counts[b1]
        end_count = cum_counts[b_C_minus_1]

        if end_count <= start_count:
            # No data between the outer boundaries: the C-3 inner boundaries
            # all sit at b1 (made increasing in step 6).
            boundaries.extend([b1] * (num_intermediate_bins - 1))
        else:
            # C-2 bins need C-1 evenly spaced count targets (start, ..., end);
            # the interior C-3 targets define b_2, ..., b_{C-2}.
            target_counts = np.linspace(start_count, end_count, num_intermediate_bins + 1)

            current_b = b1
            for target_count in target_counts[1:num_intermediate_bins]:
                b_next = np.searchsorted(cum_counts, target_count, side='left')
                # Keep boundaries strictly increasing and below the last one.
                if b_next <= current_b:
                    b_next = current_b + 1
                b_next = min(b_next, b_C_minus_1)
                boundaries.append(b_next)
                current_b = b_next

    boundaries.append(b_C_minus_1)

    # 6. Clean up: force increasing values, capped at B (so duplicates at B are possible).
    clean_boundaries = []
    last_b = -1
    for b in boundaries:
        if b <= last_b:
            b = last_b + 1
        b = min(b, B)
        clean_boundaries.append(b)
        last_b = b

    # Pad to exactly C-1 boundaries ...
    while len(clean_boundaries) < C - 1:
        clean_boundaries.append(min(clean_boundaries[-1] + 1, B))
    # ... and trim if we overshot (e.g. C == 2 produced two boundaries).
    clean_boundaries = clean_boundaries[:C-1]

    # 7. np.digitize(right=True) returns i with bins[i-1] < x <= bins[i];
    #    adding 1 maps the bins to labels {1, ..., C}.
    cutoffs_arr = np.array(clean_boundaries)
    y_matrix = np.digitize(x_matrix, cutoffs_arr, right=True) + 1
    return y_matrix, clean_boundaries

# simulate S networks given n, Lm, design: draw X and G, compute M_raw, and discretize it into M
def generate_XMG(n: int, S: int, Lm: int, design: int, Sx: np.ndarray, prob: np.ndarray):
    """Simulate S independent networks of n agents and their discretized peer counts.

    For each draw s:
    - x is drawn uniformly (np.random.choice) from ``Sx``;
    - each ordered pair (i, j) links with probability p(x_i, x_j) from ``p``.
      g[i, j] and g[j, i] are drawn independently, so links are directed
      (NOTE(review): symmetrize here if the model assumes undirected links);
    - self-links are removed (diagonal set to 0), so an agent is never its own peer;
    - m_raw[i] is the number of i's links to agents with the same x.
    M_raw is then discretized into Lm bins with ``discretize_m``.

    Args:
        n: number of agents per network.
        S: number of simulated networks.
        Lm: number of bins for m.
        design: link-probability design passed to ``p``.
        Sx: support of x.
        prob: pairwise link probabilities passed to ``p``.

    Returns:
        X: (n, S) float array of characteristics.
        G: (n, n, S) boolean adjacency matrices.
        M: (n, S) discretized m, values in {1, ..., Lm}.
        Sm: support of m, ``np.arange(1, Lm + 1)``.
    """
    X = np.zeros((n, S))
    G = np.zeros((n, n, S), dtype=bool)  # boolean storage keeps the (n, n, S) array small
    M_raw = np.zeros((n, S))             # raw same-type link counts, before discretization
    Sm = np.arange(1, Lm + 1, 1)         # support of m after discretization

    for s in range(S):
        x = np.random.choice(Sx, size=n)
        X_i = x[:, np.newaxis]  # column vector
        X_j = x[np.newaxis, :]  # row vector
        P_mat = p(X_i, X_j, design, prob)  # n-by-n link formation probabilities
        g = (np.random.rand(n, n) < P_mat).astype(int)
        # The diagonal of P_mat is non-zero, so without this an agent could link
        # to itself and count itself in m_raw and among its own neighbours.
        np.fill_diagonal(g, 0)

        # Number of links from i to agents sharing i's characteristic.
        m_raw = np.sum((X_i == X_j) * g, axis=1)

        X[:, s] = x
        G[:, :, s] = g
        M_raw[:, s] = m_raw

    M, _cutoffs = discretize_m(M_raw, C=Lm)
    return X, G, M, Sm

# calculate qn(jm, im, ix, jx) using simulated draws
def calculate_qn(X, G, M, Sx, Sm):
    """Estimate the peer-type transition probabilities qn from simulated networks.

    qn[jm, im, ix, jx] is the share, pooled over all draws, of links i -> j
    (G[i, j, s] == 1) from an agent i with x_i == Sx[ix] and m_i == Sm[im] to an
    agent j with x_j == Sx[jx], whose m_j equals Sm[jm]. Cells with no such
    links are 0. A link to a j whose m_j is not in Sm counts toward the
    denominator but toward no numerator.

    The counts are computed with one vectorized pass per draw instead of
    rescanning every draw Lx * Lx * Lm times; the resulting ratios are the same.

    Args:
        X: (n, S2) characteristics, one column per draw.
        G: (n, n, S2) adjacency matrices.
        M: (n, S2) discretized m values.
        Sx: support of x (values matched by exact equality).
        Sm: support of m (values matched by exact equality).

    Returns:
        qn: array of shape (len(Sm), len(Sm), len(Sx), len(Sx)).
    """
    Lm = len(Sm)
    Lx = len(Sx)
    S2 = X.shape[1]

    def _support_index(values, support):
        # Position of each value within `support`, or -1 if it is not in it.
        idx = np.full(values.shape, -1, dtype=np.intp)
        for k, v in enumerate(support):
            idx[values == v] = k
        return idx

    counts = np.zeros((Lm, Lm, Lx, Lx))  # [jm, im, ix, jx]: links landing on m_j == Sm[jm]
    totals = np.zeros((Lm, Lx, Lx))      # [im, ix, jx]: all links to agents of type Sx[jx]
    for s in range(S2):
        x_idx = _support_index(X[:, s], Sx)
        m_idx = _support_index(M[:, s], Sm)
        senders, receivers = np.nonzero(G[:, :, s] == 1)

        ix = x_idx[senders]
        im = m_idx[senders]
        jx = x_idx[receivers]
        jm = m_idx[receivers]

        # The sender must fall in a (x, m) cell and the receiver's x in the support.
        keep = (ix >= 0) & (im >= 0) & (jx >= 0)
        np.add.at(totals, (im[keep], ix[keep], jx[keep]), 1)
        hit = keep & (jm >= 0)
        np.add.at(counts, (jm[hit], im[hit], ix[hit], jx[hit]), 1)

    # Normalize over jm; `totals` broadcasts across the leading jm axis.
    qn = np.zeros((Lm, Lm, Lx, Lx))
    np.divide(counts, totals, out=qn, where=totals > 0)
    return qn

# estimate lambda using sample data
def estimate_lambda(m, x, a, Sx, Sm):
    """Return the cell means of ``a`` over the (x, m) grid.

    lam[ix, im] is the average of ``a`` over agents with x == Sx[ix] and
    m == Sm[im]; cells with no agents stay 0.
    """
    lam = np.zeros((len(Sx), len(Sm)))
    for ix, x_val in enumerate(Sx):
        in_x = x == x_val
        for im, m_val in enumerate(Sm):
            cell = in_x & (m == m_val)
            cell_size = np.sum(cell)
            if cell_size > 0:
                lam[ix, im] = np.sum(a * cell) / cell_size
    return lam

# define objective function for NLS
def Gn(gamma, Sx, Sm, Ix, Im, lam, Sigma, Lx):
    """NLS objective: mean squared residual of the estimated lambda.

    With ``gamma = (beta, delta, tau)``, the residual of observation i is
    lam[Ix[i], Im[i]] - Sx[Ix[i]] * beta - Sm[Im[i]] * delta - tau * Sigma[i] / Lx.
    """
    beta, delta, tau = gamma
    residuals = lam[Ix, Im] - Sx[Ix] * beta - Sm[Im] * delta - tau * Sigma / Lx
    return np.mean(residuals ** 2)

# -------------------------------------------------------------------------
#                        Main Simulation Function
# -------------------------------------------------------------------------

def _structural_parameters(case: int):
    """Return ``(beta, delta, tau, d_epsilon)`` for Monte Carlo ``case`` 1, 2 or 3.

    ``d_epsilon`` is a frozen ``scipy.stats.truncnorm`` centred at 0. Its ``a``
    and ``b`` bounds are in units of ``scale``, so the error support is
    ``[a * scale, b * scale]``.

    Raises:
        ValueError: if ``case`` is not 1, 2 or 3.
    """
    if case == 3:
        beta, delta, tau = 1.5, 0.8, 0.7
        a, b = (-1.5 - 0) / 1, (1.5 - 0) / 1
        d_epsilon = truncnorm(a=a, b=b, loc=0, scale=1)    # support [-1.5, 1.5]
    elif case == 1:
        beta, delta, tau = 3.0, 1.5, 0.8
        a, b = (-1.5 - 0) / 1, (1.5 - 0) / 1
        d_epsilon = truncnorm(a=a, b=b, loc=0, scale=3/2)  # support [-2.25, 2.25]
    elif case == 2:
        beta, delta, tau = 3.0, 1.5, 0.8
        a, b = (-3.0 - 0) / 1, (3.0 - 0) / 1
        d_epsilon = truncnorm(a=a, b=b, loc=0, scale=3)    # support [-9, 9]
    else:
        raise ValueError(f"Unsupported case {case!r}; expected 1, 2 or 3.")
    return beta, delta, tau, d_epsilon


def _link_design(design: int, convergent: int, n: int):
    """Return ``(Sx, prob)``: the support of x and the pairwise link probabilities.

    ``prob`` uses the ordering expected by ``p``: [p(0,0), p(0,1), p(1,1)] for
    designs 11/12 and [p(0,0), p(0,1), p(0,2), p(1,1), p(1,2), p(2,2)] for
    design 2. With ``convergent == 0`` the probabilities are fixed; with
    ``convergent == 1`` they are ``mu / n`` so that n * p = mu stays fixed.

    Raises:
        ValueError: if ``design`` is not 11, 12 or 2, or ``convergent`` is not 0 or 1.
    """
    if design == 11:    # x ~ Uniform{0,1}
        Sx = np.array([0.0, 1.0])
        fixed_prob = np.array([0.06, 0.04, 0.06])
        mu = np.array([10, 5, 10])
    elif design == 12:  # x ~ Uniform{0,1}
        Sx = np.array([0.0, 1.0])
        fixed_prob = np.array([0.06, 0.04, 0.06])
        mu = np.array([20, 10, 20])
    elif design == 2:   # x ~ Uniform{0,1,2}
        Sx = np.array([0.0, 1.0, 2.0])
        fixed_prob = np.array([0.06, 0.04, 0.03, 0.06, 0.04, 0.06])
        mu = np.array([10, 8, 5, 12, 8, 10])
    else:
        raise ValueError(f"Unsupported design {design!r}; expected 11, 12 or 2.")

    if convergent == 0:    # link formation probability is fixed
        prob = fixed_prob
    elif convergent == 1:  # link formation probability shrinks as n grows
        prob = mu / n      # mu = n * p
    else:
        raise ValueError(f"Unsupported convergent {convergent!r}; expected 0 or 1.")
    return Sx, prob


def _expected_peer_lambda(lambda_mat, qn, Lx: int, Lm: int):
    """Return E[ix, im] = sum_jx (1/Lx) * sum_jm lambda_mat[jx, jm] * qn[jm, im, ix, jx].

    Equal weights 1/Lx stand in for omega(ix, jx). The explicit loops keep the
    original summation order, so results match the previous inline code exactly.
    """
    E = np.zeros((Lx, Lm))
    for ix in range(Lx):
        for im in range(Lm):
            total = 0
            for jx in range(Lx):
                inner_sum = 0
                for jm in range(Lm):
                    inner_sum += lambda_mat[jx, jm] * qn[jm, im, ix, jx]
                total += (1/Lx) * inner_sum
            E[ix, im] = total
    return E


def run_simulation(n: int, design: int, case: int, S: int = 200, S2: int = 10000, convergent: int = 1, Lm: int = 5):
    """Run one Monte Carlo scenario and save the NLS estimation results.

    Steps:
    1. Simulate a pool of S2 networks, estimate qn from it, and solve for
       lambda_n by fixed-point iteration.
    2. Generate choices A in the pool and draw S networks (without replacement).
    3. For each sampled network, estimate (beta, delta, tau) by Nelder-Mead NLS.
    4. Save bias, variance, MSE, the estimates, convergence flags and objective
       values to ``<cwd>/output/output_n{n}_defn2_design{design}_case{case}.npz``
       (the folder is created if missing).

    Args:
        n: agents per network.
        design: 11, 12 or 2 (support of x and link probabilities).
        case: 1, 2 or 3 (structural parameters and error distribution).
        S: number of Monte Carlo samples drawn from the pool.
        S2: size of the pool of simulated networks. Memory note: the pool's
            adjacency array has shape (n, n, S2).
        convergent: 0 for fixed link probabilities, 1 for probabilities mu / n.
        Lm: number of bins for m.

    Raises:
        ValueError: if ``case``, ``design`` or ``convergent`` is not supported.
    """
    start_time = time.time()

    # --- PART 0: parameters and specifications (validated up front)
    beta, delta, tau, d_epsilon = _structural_parameters(case)
    Sx, prob = _link_design(design, convergent, n)
    Lx = len(Sx)

    # --- Part 1: solve for asymptotic moments via simulation (S2 network draws)
    X, G, M, Sm = generate_XMG(n, S2, Lm, design, Sx, prob)
    qn = calculate_qn(X, G, M, Sx, Sm)

    # lambda_n via contraction mapping:
    # lambda(x, m) = x * beta + m * delta + tau * E[lambda | x, m]
    tol = 1e-6
    maxit = 5000
    lambda_n = np.ones((Lx, Lm))
    for iteration in range(maxit):
        E_lambda_n = _expected_peer_lambda(lambda_n, qn, Lx, Lm)
        lambda_n_prime = Sx[:, np.newaxis] * beta + Sm[np.newaxis, :] * delta + tau * E_lambda_n
        err = np.max(np.abs(lambda_n - lambda_n_prime))
        lambda_n = lambda_n_prime  # a fresh array every iteration, so no copy needed
        if err < tol:
            # `iteration` is 0-based, so iteration + 1 rounds have been run.
            print(f"Iteration of strategy function reached tolerance after {iteration + 1} rounds.")
            break
    else:
        print("Maximum number of iterations reached.")

    # --- Part 2: simulate choices in the S2 pooled networks, then draw S samples
    Epsilon = d_epsilon.rvs(size=(n, S2))  # draw error terms

    # Expected peer lambda for each (x, m) cell at the fixed point.
    Ea = _expected_peer_lambda(lambda_n, qn, Lx, Lm)

    A = np.zeros((n, S2))
    for ix in range(Lx):
        for im in range(Lm):
            mask = (X == Sx[ix]) & (M == Sm[im])
            A[mask] = X[mask] * beta + M[mask] * delta + tau * Ea[ix, im] + Epsilon[mask]

    sample_index = np.random.choice(S2, size=S, replace=False)
    x_data = X[:, sample_index]
    m_data = M[:, sample_index]
    a_data = A[:, sample_index]
    g_data = G[:, :, sample_index]

    # --- Part 3: estimate parameters using NLS
    param_est = np.zeros((3, S))           # parameter estimates per sample
    flag_mat = np.zeros((S,), dtype=bool)  # optimizer success flag per sample
    funval_mat = np.zeros((S,))            # objective value at the optimum per sample

    for s in range(S):
        x_s, m_s, a_s = x_data[:, s], m_data[:, s], a_data[:, s]

        # estimate transition probabilities and lambda from this one network
        qn_hat_s = calculate_qn(x_data[:, s:s+1], g_data[:, :, s:s+1], m_data[:, s:s+1], Sx, Sm)
        lambda_hat_s = estimate_lambda(m_s, x_s, a_s, Sx, Sm)

        # positions of each agent's x and m in the supports (Sx and Sm must be sorted)
        Ix_s = np.searchsorted(Sx, x_s)
        Im_s = np.searchsorted(Sm, m_s)
        # qn_hat_s[:, Im_s, Ix_s, :] has shape (Lm, n, Lx); move the agent axis first.
        qn_indexed_s = np.transpose(qn_hat_s[:, Im_s, Ix_s, :], (1, 0, 2))
        # Sigma_s[i] = sum_{jx, jm} qn_hat[jm, Im[i], Ix[i], jx] * lambda_hat[jx, jm]
        Sigma_s = np.zeros(n)
        for jx in range(Lx):
            for jm in range(Lm):
                Sigma_s += qn_indexed_s[:, jm, jx] * lambda_hat_s[jx, jm]

        # objective with everything except gamma held fixed
        Gn_s = partial(Gn, Sx=Sx, Sm=Sm, Lx=Lx, Ix=Ix_s, Im=Im_s, lam=lambda_hat_s, Sigma=Sigma_s)

        result_s = minimize(fun=Gn_s, x0=np.array([beta-1, delta+1, .1]), method='Nelder-Mead',
                            options={'xatol': 1e-3, 'maxiter': 10000, 'fatol': 1e-4})
        flag_mat[s] = result_s.success
        funval_mat[s] = result_s.fun
        param_est[:, s] = result_s.x

    # --- Part 4: performance metrics, saved to disk
    true_params = np.array([beta, delta, tau])
    bias = np.mean(param_est, axis=1) - true_params
    var = np.var(param_est, axis=1)
    mse = np.mean((param_est - true_params[:, np.newaxis])**2, axis=1)

    filename = f'output_n{n}_defn2_design{design}_case{case}'
    output_dir = os.path.join(os.getcwd(), 'output')
    os.makedirs(output_dir, exist_ok=True)  # np.savez does not create missing folders
    fullname = os.path.join(output_dir, filename)
    np.savez(fullname, bias=bias, var=var, mse=mse, param_est=param_est, flag_mat=flag_mat, funval_mat=funval_mat)

    run_time = time.time() - start_time
    print(f"The code for defn = 2, n={n}, design={design}, case={case} ran in {run_time:.4f} seconds.")

# ----------------------------------------------------------------------
# Outer loop that runs all scenarios (executes only when this file is run as a script)
# ----------------------------------------------------------------------

if __name__ == '__main__':
    import traceback  # full stack traces for failed scenarios

    # define the scenarios: (n, design, case)
    scenarios = [
        # ( 50, 11, 1), ( 50, 12, 1), ( 50, 2, 1),
        # ( 50, 11, 2), ( 50, 12, 2), ( 50, 2, 2),
        (100, 11, 1), (100, 12, 1), (100, 2, 1),
        (100, 11, 2), (100, 12, 2), (100, 2, 2),
        (200, 11, 1), (200, 12, 1), (200, 2, 1),
        (200, 11, 2), (200, 12, 2), (200, 2, 2),
        (400, 11, 1), (400, 12, 1), (400, 2, 1),
        (400, 11, 2), (400, 12, 2), (400, 2, 2)    ]

    for n_val, design_val, case_val in scenarios:
        print(f"\n--- Running Scenario: defn = 2, n={n_val}, design={design_val}, case={case_val} ---")
        try:
            run_simulation(n_val, design_val, case_val)
        except Exception as e:
            # Best effort: report the failure (including where it happened) and
            # continue with the next scenario.
            print(f"!!! Error in Scenario: defn = 2, n={n_val}, design={design_val}, case={case_val} !!!")
            print(f"Error details: {e}")
            traceback.print_exc()
        else:
            print(f"Finished Scenario: defn = 2, n={n_val}, design={design_val}, case={case_val}")